{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# Partitioned Training\n",
    "The following example shows how to train a model when the dataset is too large to fit in memory and thus requires partitioning."
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "f7d979c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('../..')\n",
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '1'\n",
    "os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'\n",
    "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'\n",
    "import tensorflow as tf\n",
    "tf.get_logger().setLevel('ERROR')\n",
    "import numpy as np\n",
    "import ampligraph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "cebeec31",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import the KGE model\n",
    "from ampligraph.latent_features import ScoringBasedEmbeddingModel\n",
    "from ampligraph.datasets.sqlite_adapter import SQLiteAdapter\n",
    "from ampligraph.datasets.graph_data_loader import GraphDataLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "87ed1243",
   "metadata": {},
   "outputs": [],
   "source": [
    "PATH_TO_DATASET = 'your/path/to/dataset/'\n",
    "\n",
    "# Graph loader - loads the data from the file, numpy array, etc and generates batchs for iterating\n",
    "# Internally it will first map raw data to indices and store in db.\n",
    "# then this will map the raw triples to indices and store in another db\n",
    "dataset_loader = GraphDataLoader(PATH_TO_DATASET + 'fb15k-237/train.txt', \n",
    "                                  backend=SQLiteAdapter, # type of backend to use\n",
    "                                  batch_size=1000,       # batch size to use while iterating over this dataset\n",
    "                                  dataset_type='train',  # dataset type\n",
    "                                  use_filter=False,      # Whether to use filter or not\n",
    "                                  use_indexer=True)      # indicates that the data needs to be mapped to index\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "24cb1627",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_split: memory before: 52.366MB, after: 65.027MB, consumed: 12.661MB; exec time: 95.922s\n"
     ]
    }
   ],
   "source": [
    "# Choose the partitioner - in this case we choose BucketGraphPartitioner partitioner\n",
    "from ampligraph.datasets import BucketGraphPartitioner\n",
    "partitioner = BucketGraphPartitioner(dataset_loader, k=3)\n",
    "\n",
    "# The above code will create a partitioner by passing the graph dataloader object\n",
    "# the partitioner will partition the data and will internally create multiple graph \n",
    "# data loaders for each partition.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "badb504e",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/alberto.bernardi/miniforge3/lib/python3.10/site-packages/keras/initializers/initializers_v2.py:120: UserWarning: The initializer GlorotUniform is unseeded and being called multiple times, which will return identical values  each time (even if the initializer is unseeded). Please update your code to provide a seed to the initializer, or avoid using the same initalizer instance more than once.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/10\n",
      "277/277 [==============================] - 70s 253ms/step - loss: 1083.1062\n",
      "Epoch 2/10\n",
      "277/277 [==============================] - 38s 137ms/step - loss: 1082.0710\n",
      "Epoch 3/10\n",
      "277/277 [==============================] - 28s 101ms/step - loss: 1071.3022\n",
      "Epoch 4/10\n",
      "277/277 [==============================] - 28s 101ms/step - loss: 1034.2253\n",
      "Epoch 5/10\n",
      "277/277 [==============================] - 37s 135ms/step - loss: 971.5778\n",
      "Epoch 6/10\n",
      "277/277 [==============================] - 48s 173ms/step - loss: 900.3652\n",
      "Epoch 7/10\n",
      "277/277 [==============================] - 58s 211ms/step - loss: 832.1041\n",
      "Epoch 8/10\n",
      "277/277 [==============================] - 110s 398ms/step - loss: 770.6450\n",
      "Epoch 9/10\n",
      "277/277 [==============================] - 163s 589ms/step - loss: 716.6290\n",
      "Epoch 10/10\n",
      "277/277 [==============================] - 165s 597ms/step - loss: 669.4027\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x28d5002e0>"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# create and compile a model as usual\n",
    "partitioned_model = ScoringBasedEmbeddingModel(eta=2, \n",
    "                                               k=50, \n",
    "                                               scoring_type='DistMult')\n",
    "\n",
    "partitioned_model.compile(optimizer='adam', loss='multiclass_nll')\n",
    "\n",
    "partitioned_model.fit(partitioner,            # Pass the partitioner object as input to the fit function\n",
    "                                              # this will generate data for the model during training\n",
    "                                              # No need to pass partitioning_k parameter as this will be \n",
    "                                              # overridden by partitioner_k of input partitioner\n",
    "                      batch_size=1000,        # Batch size\n",
    "                      epochs=10)              # number of epochs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "f8d55c73",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "28 triples containing invalid keys skipped!\n"
     ]
    }
   ],
   "source": [
    "# Create an instance of graph(triple) loader for test set by passing sql backend. \n",
    "# the data will be indexed using the models training indexer\n",
    "# and the indexed triples will be stored in a database in chunks\n",
    "dataset_loader_test = GraphDataLoader(PATH_TO_DATASET + 'fb15k-237/test.txt',\n",
    "                                      backend=SQLiteAdapter,                         # Type of backend to use\n",
    "                                      batch_size=400,                                # Batch size to use while iterating over this dataset\n",
    "                                      dataset_type='test',                           # Dataset type\n",
    "                                      use_indexer=partitioned_model.data_indexer)    # Get the data_indexer from the trained model \n",
    "                                                                                     # and map the concepts to the same indices used at training\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "10f32d1f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "53/53 [==============================] - 27s 504ms/step\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(672.3955377238477, 0.0827018782039032, 0.0, 0.22237988061454153, 20438)"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ranks = partitioned_model.evaluate(dataset_loader_test, # pass the dataloader object as input to the \n",
    "                                                        # evaluate function. this will generate data\n",
    "                                                        # for the model during training\n",
    "                                   batch_size=400)\n",
    "from ampligraph.evaluation.metrics import mrr_score, hits_at_n_score, mr_score\n",
    "mr_score(ranks), mrr_score(ranks), hits_at_n_score(ranks, 1), hits_at_n_score(ranks, 10), len(ranks)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "f65ef7f9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING - Found untraced functions such as _get_ranks while saving (showing 1 of 1). These functions will not be directly callable after loading.\n"
     ]
    }
   ],
   "source": [
    "from ampligraph.utils import save_model\n",
    "save_model(model=partitioned_model, model_name_path='./partitioned_model_bucket')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "5813a62e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from ampligraph.utils import restore_model\n",
    "model = restore_model('./partitioned_model_bucket')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "c77a8342",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "28 triples containing invalid keys skipped!\n"
     ]
    }
   ],
   "source": [
    "dataset_loader_test = GraphDataLoader(PATH_TO_DATASET + 'fb15k-237/test.txt', \n",
    "                                        backend=SQLiteAdapter,          # type of backend to use\n",
    "                                        batch_size=400,                 # batch size to use while iterating over this dataset\n",
    "                                        dataset_type='test',            # dataset type\n",
    "                                        use_indexer=model.data_indexer)    # get the mapper from the trained model \n",
    "                                                                                                    # and map the concepts to same indices \n",
    "                                                                                                    # as used during training\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "c21abb55",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "53/53 [==============================] - 27s 506ms/step\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(672.3955377238477, 0.0827018782039032, 0.0, 0.22237988061454153, 20438)"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ranks = model.evaluate(dataset_loader_test, # pass the dataloader object as input to the \n",
    "                                            # evaluate function. this will generate data\n",
    "                                            # for the model during training\n",
    "                       batch_size=400)\n",
    "from ampligraph.evaluation.metrics import mrr_score, hits_at_n_score, mr_score\n",
    "mr_score(ranks), mrr_score(ranks), hits_at_n_score(ranks, 1), hits_at_n_score(ranks, 10), len(ranks)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10.6 ('base')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  },
  "vscode": {
   "interpreter": {
    "hash": "2e69f3670cdad0193847aaa0b77be56c05c951fcbdd384ff882dde0464f4de76"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
